import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import gymnasium as gym

from models.policy_net import ActorCriticRNN, SharedActorCriticRNN, nmActorCriticRNN

from utils import helpers as utl


# Default torch device for this module: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_trained_network(
    network_type,  # 'a2crnn'
    path_to_trained_network_state_dict,
    args,
    device=device
):
    '''
    build a policy network from args and load a trained state dict into it

    The architecture is selected from args:
      - args.NMd is truthy          -> nmActorCriticRNN (also gets N_nm=args.nNM)
      - else not args.shared_rnn    -> ActorCriticRNN
      - else                        -> SharedActorCriticRNN

    Parameters
    ----------
    network_type : only 'a2crnn' is supported
    path_to_trained_network_state_dict : path passed to torch.load
    args : namespace holding the network hyper-parameters read below
    device : device the weights are mapped to and the network is moved to

    Returns
    -------
    the constructed network with the loaded weights, on `device`

    Raises
    ------
    ValueError if network_type is not 'a2crnn'
    '''
    if network_type != 'a2crnn':
        raise ValueError(f'invalid network_type: {network_type}')

    # constructor arguments shared by all three architectures
    common_kwargs = dict(
        args=args,
        layers_before_rnn=args.layers_before_rnn,
        rnn_hidden_dim=args.rnn_hidden_dim,
        layers_after_rnn=args.layers_after_rnn,
        rnn_cell_type=args.rnn_cell_type,
        activation_function=args.policy_net_activation_function,
        state_dim=args.input_state_dim_for_policy,
        state_embed_dim=args.state_embed_dim,
        action_dim=args.action_dim,
        action_embed_dim=args.action_embed_dim,
        action_space_type=args.action_space_type,
        reward_dim=args.reward_dim,
        reward_embed_dim=args.reward_embed_dim,
    )

    if args.NMd:
        # NOTE(review): this constructor is called with the British spelling
        # `initialisation_method`, unlike the other two networks; kept as-is
        # because the constructor signature is not visible here -- confirm.
        network = nmActorCriticRNN(
            initialisation_method=args.policy_net_initialization_method,
            N_nm=args.nNM,
            **common_kwargs
        )
    elif not args.shared_rnn:
        network = ActorCriticRNN(
            initialization_method=args.policy_net_initialization_method,
            **common_kwargs
        )
    else:
        network = SharedActorCriticRNN(
            initialization_method=args.policy_net_initialization_method,
            **common_kwargs
        )

    # NOTE: torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    state_dict = torch.load(
        path_to_trained_network_state_dict,
        map_location=device
    )
    network.load_state_dict(state_dict)
    network.to(device)

    return network


#########################################################
# TEST ROLLOUT
#########################################################
def rollout_one_episode_rl2(
    env,
    policy_network,
    args,
    deterministic=False
):
    '''
    rollout the trained policy_network to 
    play one episode in the env

    At every step the network receives the current state S_{t}, the one-hot
    previous action A_{t-1}, the previous reward R_{t-1} and the previous
    hidden state(s); the chosen action is then applied to env. The loop ends
    when env reports terminated or truncated, and env is closed before
    returning (the caller's env object is closed by this function).

    All tensors are moved to the module-level `device`.

    Returns a tuple
        info, states, actions, rewards,
        action_log_probs, all_action_logits, entropies, state_values,
        actor_hidden_states, critic_hidden_states
    where info is the info returned by the last env.step and the remaining
    entries are numpy arrays with one entry per step. If args.shared_rnn,
    actor_hidden_states holds the shared rnn hidden states and
    critic_hidden_states is an empty array.

    NOTE(review): when args.NMd is set, only rnn_prev_hidden_state is
    initialized, while the loop below branches on args.shared_rnn alone;
    this looks like it requires args.shared_rnn to be set together with
    args.NMd -- confirm.
    '''
    
    # per-step records, converted to numpy arrays before returning
    states = []
    actions = []
    rewards = []
    state_values = []
    action_log_probs = []
    all_action_logits = []
    entropies = []
    actor_hidden_states = []
    critic_hidden_states = []

    # reset the env to get an initial state
    curr_state_dict, info = env.reset()
    # input_state_for_policy: already accounted for time
    curr_state = utl.get_states_from_state_dicts(
        curr_state_dict, args.env_name, args.time_as_state
    )
    curr_state = torch.from_numpy(curr_state).float().\
        reshape((1, 1, args.input_state_dim_for_policy)).to(device)
    # there is no previous action / reward at the first step: use zeros
    prev_action = torch.zeros(1, 1, args.action_dim).to(device)
    prev_reward = torch.zeros(1, 1, 1).to(device)
    # initialize ActorCriticRNN hidden states
    if args.NMd:
        # hidden state is widened by args.nNM extra units
        rnn_prev_hidden_state = torch.zeros(1, 1, args.rnn_hidden_dim + args.nNM).to(device)
    elif not args.shared_rnn:
        actor_prev_hidden_state = torch.zeros(1, 1, args.rnn_hidden_dim).to(device)
        critic_prev_hidden_state = torch.zeros(1, 1, args.rnn_hidden_dim).to(device)
    elif args.shared_rnn:
        rnn_prev_hidden_state = torch.zeros(1, 1, args.rnn_hidden_dim).to(device)

    # rollout
    done = False
    while not done:
        # record S_{t} before acting
        states.append(curr_state.squeeze().detach().cpu().numpy())
        
        # select an action A_{t} with inputs
        # S_{t}, A_{t-1}, R_{t-1} and hidden states
        with torch.no_grad():
            if not args.shared_rnn:
                # act
                action_categorical, action_log_prob, entropy, state_value, \
                    actor_hidden_state, critic_hidden_state = \
                        policy_network.act(
                            curr_states=curr_state,
                            prev_actions=prev_action,
                            prev_rewards=prev_reward,
                            actor_prev_hidden_states=actor_prev_hidden_state,
                            critic_prev_hidden_states=critic_prev_hidden_state,
                            return_prior=False, 
                            deterministic=deterministic
                        )
                # second forward pass with the same inputs, only to
                # record the logits over all actions
                all_action_logit, _, _, _ = policy_network(
                    curr_states=curr_state, 
                    prev_actions=prev_action, 
                    prev_rewards=prev_reward,
                    actor_prev_hidden_states=actor_prev_hidden_state, 
                    critic_prev_hidden_states=critic_prev_hidden_state,
                    return_prior=False
                )
            elif args.shared_rnn:
                # act
                action_categorical, action_log_prob, entropy, state_value, \
                    rnn_hidden_state = \
                        policy_network.act(
                            curr_states=curr_state,
                            prev_actions=prev_action,
                            prev_rewards=prev_reward,
                            rnn_prev_hidden_states=rnn_prev_hidden_state,
                            return_prior=False, 
                            deterministic=deterministic
                        )
                # second forward pass with the same inputs, only to
                # record the logits over all actions
                all_action_logit, _, _, = policy_network(
                    curr_states=curr_state, 
                    prev_actions=prev_action, 
                    prev_rewards=prev_reward,
                    rnn_prev_hidden_states=rnn_prev_hidden_state,
                    return_prior=False
                )

        # perform the action A_{t} in the environment 
        # to get S_{t+1} and R_{t+1}
        next_state_dict, reward, terminated, truncated, info = env.step(
            action_categorical.squeeze().cpu().numpy()
        )
        next_state = utl.get_states_from_state_dicts(
            next_state_dict, args.env_name, args.time_as_state
        )
        next_state = torch.from_numpy(next_state).float()\
            .reshape(1, 1, args.input_state_dim_for_policy)
        # the network takes the previous action as a one-hot vector
        action = F.one_hot(action_categorical, num_classes=args.action_dim).\
            float().reshape((1, 1, args.action_dim))
        reward = torch.from_numpy(np.array(reward)).float().reshape(1, 1, 1)

        # update for next step
        curr_state = next_state.to(device)
        prev_action = action.to(device)
        prev_reward = reward.to(device)
        if not args.shared_rnn:
            actor_prev_hidden_state = actor_hidden_state.to(device)
            critic_prev_hidden_state = critic_hidden_state.to(device)
        elif args.shared_rnn:
            rnn_prev_hidden_state = rnn_hidden_state.to(device)

        # update if the environment is done
        done = terminated or truncated
        
        # record this step's outputs
        actions.append(action_categorical.squeeze().detach().cpu().numpy())
        action_log_probs.append(action_log_prob.squeeze().detach().cpu().numpy())
        all_action_logits.append(all_action_logit.squeeze().detach().cpu().numpy())
        entropies.append(entropy.squeeze().detach().cpu().numpy())
        rewards.append(reward.squeeze().detach().cpu().numpy())
        state_values.append(state_value.squeeze().detach().cpu().numpy())
        if not args.shared_rnn:
            actor_hidden_states.append(actor_hidden_state.squeeze().detach().cpu().numpy())
            critic_hidden_states.append(critic_hidden_state.squeeze().detach().cpu().numpy())
        elif args.shared_rnn:
            # if shared_rnn: use actor_hidden_states to store rnn_hidden_states;
            # nothing is appended to critic_hidden_states, so it stays empty
            actor_hidden_states.append(rnn_hidden_state.squeeze().detach().cpu().numpy())

    # the env passed in by the caller is closed here
    env.close()

    states = np.array(states)
    actions = np.array(actions)
    action_log_probs = np.array(action_log_probs)
    all_action_logits = np.array(all_action_logits)
    entropies = np.array(entropies)
    rewards = np.array(rewards)
    state_values = np.array(state_values)
    actor_hidden_states = np.array(actor_hidden_states)
    critic_hidden_states = np.array(critic_hidden_states)

    # info is the one returned by the last env.step
    return info, states, actions, rewards, \
        action_log_probs, all_action_logits, entropies, state_values, \
        actor_hidden_states, critic_hidden_states


#########################################################
# PERFORMANCE
#########################################################
def get_empirical_returns(
    env_name,
    args,
    encoder,  # None if rl2
    policy_network,
    num_envs=10
):
    '''
    estimate the episode return of policy_network by 
    rolling it out (stochastically) in num_envs fresh envs

    Each env is created with gym.make(f'environments.bandit:{env_name}').
    Only args.exp_label in ['rl2', 'noisy_rl2'] is supported; anything else
    raises ValueError. `encoder` is not used by this function.

    Returns (mean, std) of the summed rewards over the num_envs episodes.
    '''
    supported_exp_labels = ['rl2', 'noisy_rl2']

    episode_returns = []
    for _ in range(num_envs):
        env = gym.make(f'environments.bandit:{env_name}')

        if args.exp_label not in supported_exp_labels:
            raise ValueError(f'incompatible model type: {args.exp_label}')

        # only the rewards of the rollout are needed here
        (_, _, _, episode_rewards,
         _, _, _, _,
         _, _) = rollout_one_episode_rl2(
            env,
            policy_network,
            args,
            deterministic=False
        )
        episode_returns.append(np.sum(episode_rewards))

    episode_returns = np.array(episode_returns)
    mean_return = np.average(episode_returns)
    std_return = np.std(episode_returns)

    return mean_return, std_return


def get_max_episode_return(
    args,
    env
):
    '''
    return the theoretically maximum (expected) return of one episode, 
    i.e. the return of always choosing the arm with the highest 
    reward probability

    The env family is the part of args.env_name before the first '-':
      - CoupledBlockDF: sum over blocks of max(env.block_p_reward[b]) * block_len
      - UncoupledBlockDF / RandomWalkDF: sum over trials of max(env.trial_p_reward[t])
      - BlockBandit2ArmCoupledEasy: sum over blocks of max(env.block_p_bandits[b]) * block_len

    raises ValueError for any other env family
    '''
    env_family = args.env_name.split('-')[0]

    if env_family == 'CoupledBlockDF':
        return sum(
            np.max(env.block_p_reward[block_id]) * n_trials
            for block_id, n_trials in enumerate(env.block_lens)
        )

    if env_family in ('UncoupledBlockDF', 'RandomWalkDF'):
        best_p_per_trial = np.max(env.trial_p_reward, axis=1)
        return np.sum(best_p_per_trial)

    if env_family == 'BlockBandit2ArmCoupledEasy':
        return sum(
            np.max(env.block_p_bandits[block_id]) * n_trials
            for block_id, n_trials in enumerate(env.block_lens)
        )

    raise ValueError(f'invalid env_name: {args.env_name}')


def get_optimality_score(
    args,
    env, 
    actions
):
    '''
    calculate the frequency that a policy chooses the better alternative, 
    with equally-rewarding trials / blocks omitted

    Parameters
    ----------
    args : needs args.env_name; the env family is the part before the first '-'
    env : for CoupledBlockDF / UncoupledBlockDF / RandomWalkDF it must expose
        env.trial_p_reward (per-trial reward prob of the two arms);
        for BlockBandit2ArmCoupledEasy it must expose env.block_lens and
        env.block_p_bandits
    actions : sequence of chosen actions, one per trial (list or numpy array)

    Returns
    -------
    fraction of non-omitted trials in which the arm with the higher reward
    probability was chosen; nan if every trial was omitted

    Raises
    ------
    ValueError for any other env family
    '''
    # a plain list would make `actions[a:b] == optimal_action` evaluate to
    # a single False instead of an element-wise comparison
    actions = np.asarray(actions)
    env_family = args.env_name.split('-')[0]

    num_optimal_action_trials = 0
    num_sided_trials = len(actions)

    if env_family in ['CoupledBlockDF', 'UncoupledBlockDF', 'RandomWalkDF']:
        for trial_id, p_reward in enumerate(env.trial_p_reward):
            if p_reward[0] == p_reward[1]:
                # if equal rew prob, exclude from calculating
                num_sided_trials -= 1
            else:
                optimal_action = np.argmax(p_reward)
                num_optimal_action_trials += int(actions[trial_id] == optimal_action)

    elif env_family in ['BlockBandit2ArmCoupledEasy']:
        block_start = 0
        for block_ind, block_len in enumerate(env.block_lens):
            block_end = block_start + block_len
            # a block is treated as equally-rewarding when its best arm
            # has reward prob 0.5
            if np.max(env.block_p_bandits[block_ind]) == 0.5:
                # if equal rew prob, exclude from calculating
                num_sided_trials -= block_len
            else:
                optimal_action = np.argmax(env.block_p_bandits[block_ind])
                num_optimal_action_trials += np.sum(
                    actions[block_start:block_end] == optimal_action
                )
            block_start = block_end
    else:
        raise ValueError(f'invalid env_name: {args.env_name}')

    if num_sided_trials <= 0:
        # no trial had a better alternative: the score is undefined
        return float('nan')

    return num_optimal_action_trials / float(num_sided_trials)


#########################################################
# PLOTTING
#########################################################
def shading_blocks(
    ax, 
    args,
    info
):
    '''
    shade every second block (block index 1, 3, 5, ...) of the session 
    in gray over the current y-range of ax

    block lengths are read from info['task_object'].block_lens for 
    CoupledBlockDF / UncoupledBlockDF and from info['block_lens'] for 
    BlockBandit2ArmCoupledEasy; any other env family raises ValueError
    '''
    y_low, y_high = ax.get_ylim()
    env_family = args.env_name.split('-')[0]

    if env_family in ('CoupledBlockDF', 'UncoupledBlockDF'):
        block_lens = info['task_object'].block_lens
    elif env_family == 'BlockBandit2ArmCoupledEasy':
        block_lens = info['block_lens']
    else:
        raise ValueError(f'invalid env_name: {args.env_name}')

    left_edge = 0
    for block_id, n_trials in enumerate(block_lens):
        right_edge = left_edge + n_trials
        is_odd_block = block_id % 2 == 1
        if is_odd_block:
            ax.fill_between(
                [left_edge, right_edge],
                y_low, y_high,
                color='gray', alpha=0.2
            )
        left_edge = right_edge


def _overlay_true_reward_prob(ax, info, y_lim, y_ticks=None):
    '''
    plot the true per-trial reward probability of both arms 
    (info['task_object'].trial_p_reward) on a twin y-axis of ax 
    and return that twin axis
    '''
    arm_colors = ['gray', 'k']
    trial_p_reward = info['task_object'].trial_p_reward

    twin_ax = ax.twinx()
    for arm_id in (0, 1):
        arm_p_reward = trial_p_reward[:, arm_id]
        twin_ax.plot(
            np.arange(len(arm_p_reward)),
            arm_p_reward,
            c=arm_colors[arm_id],
            label=f'p(r|a{arm_id})'
        )
    twin_ax.set_ylabel('reward prob')
    twin_ax.set_ylim(*y_lim)
    if y_ticks is not None:
        twin_ax.set_yticks(y_ticks)
    twin_ax.legend()

    return twin_ax


def _plot_hidden_units(ax, hidden_states, n_units, name):
    '''
    plot the time course of the first n_units columns of hidden_states; 
    name ('rnn' / 'actor' / 'critic') is used in the title and y-label
    '''
    # never ask for more units than the network has
    n_units = min(n_units, hidden_states.shape[1])

    ax.set_title(f'{name} hidden states')
    for unit_id in range(n_units):
        ax.plot(hidden_states[:, unit_id], lw=0.9)
    ax.set_xlabel('trial')
    ax.set_ylabel(f'{name}_hidden_state')


def _visualize_policy_df_rl2(
    model_name,
    env,
    args,
    info,
    actions,
    rewards,
    action_log_probs,
    all_action_logits,
    entropies,
    state_values,
    actor_hidden_states,
    critic_hidden_states,
    running_average_window,
    n_hidden_units_plot,
    reward_prob_top,
    reward_prob_margin,
    reward_prob_yticks,
    shade_blocks
):
    '''
    shared implementation of the visualize_policy_*DF_rl2 figures

    The figure has one row per panel: actions, chosen action prob, policy
    entropy, state value, p(a1), actor (or shared rnn) hidden states and,
    if not args.shared_rnn, critic hidden states.

    The env-specific parts are:
      - reward_prob_top / reward_prob_margin: the reward-prob twin axes span
        [0 - margin*k, top + margin*k] with k=8 in the action panel and
        k=2 in the p(a1) panel
      - reward_prob_yticks: ticks of the reward-prob twin axes (None: default)
      - shade_blocks: whether to call shading_blocks on every panel

    returns the matplotlib figure
    '''
    def shade(axis):
        if shade_blocks:
            shading_blocks(axis, args, info)

    # figure setup
    n_rows = 6 if args.shared_rnn else 7
    fig, axs = plt.subplots(
        nrows=n_rows, ncols=1,
        figsize=(12, 3*n_rows), dpi=300
    )
    fig.suptitle(f"{model_name} in {args.env_name}")

    # -- PLOT 0: actions trial view --
    ax = axs[0]

    episode_return = rewards.sum()
    max_episode_return = get_max_episode_return(args, env.unwrapped)
    optimality_score = get_optimality_score(args, env.unwrapped, actions)
    ax.set_title(
        f'Session return {episode_return}, (max {max_episode_return:.1f}), '
        f'optimality score: {optimality_score:.2f}'
    )

    # event raster: one row of ticks per (action, reward) pair;
    # the offsets / lengths below are laid out for two actions
    events = [
        np.where((actions == act) & (rewards == rew))[0]
        for act in range(args.action_dim)
        for rew in [0, 1]
    ]
    ax.eventplot(
        events,
        lineoffsets=[-0.3, -0.3, 1.3, 1.3],
        linelengths=[0.2, 0.4, 0.2, 0.4],
        linewidth=1
    )

    # action running average
    actions_moving_average = np.convolve(
        np.array(actions),
        np.ones(running_average_window),
        mode="same") / running_average_window
    ax.plot(
        np.arange(len(actions_moving_average)),
        actions_moving_average
    )

    ax.set_xlabel('trial')
    ax.set_ylabel('action (running average freq.)')
    ax.set_ylim(0-0.1*8, 1+0.1*8)
    ax.set_yticks([0, 1], [0, 1])

    # true reward prob
    twin_ax = _overlay_true_reward_prob(
        ax, info,
        y_lim=(0-reward_prob_margin*8, reward_prob_top+reward_prob_margin*8),
        y_ticks=reward_prob_yticks
    )
    shade(twin_ax)

    # -- PLOT 1: action_log_prob --
    ax = axs[1]
    ax.set_title('Chosen action prob')
    ax.plot(np.exp(action_log_probs))
    ax.set_xlabel('trial')
    ax.set_ylabel('prob')
    ax.set_ylim(0, 1.05)
    shade(ax)

    # -- PLOT 2: policy entropy --
    ax = axs[2]
    ax.set_title('Policy entropy')
    ax.plot(entropies)
    # NOTE(review): the reference line is in bits (log2); if the entropies
    # are computed in nats the uniform-policy bound would be np.log -- confirm
    max_entropy = np.log2(args.action_dim)
    ax.hlines(
        max_entropy,
        0, len(entropies),
        linestyles='dashed',
        colors='k',
        label='max entropy (uniform)'
    )
    ax.set_xlabel('trial')
    ax.set_ylabel('entropy')
    ax.set_ylim(0, max_entropy+0.1)
    ax.legend()
    shade(ax)

    # -- PLOT 3: state_values --
    ax = axs[3]
    ax.set_title('State value')
    ax.plot(state_values)
    ax.set_xlabel('trial')
    ax.set_ylabel('state_value')
    shade(ax)

    # -- PLOT 4: prob_a1 --
    ax = axs[4]
    ax.set_title('policy: p(a1)')

    # softmax over actions; subtracting the row max keeps np.exp from
    # overflowing for large logits without changing the result
    logits = np.asarray(all_action_logits)
    exp_logits = np.exp(logits - logits.max(axis=1, keepdims=True))
    prob_a1 = exp_logits[:, 1] / exp_logits.sum(axis=1)
    ax.plot(
        np.arange(len(prob_a1)),
        prob_a1,
        label='p(a1)',
    )

    ax.set_xlabel('trial')
    ax.set_ylabel('p(a1)')
    ax.set_ylim(0-0.1*2, 1+0.1*2)
    ax.set_yticks([0, 1], [0, 1])

    # true reward prob
    twin_ax = _overlay_true_reward_prob(
        ax, info,
        y_lim=(0-reward_prob_margin*2, reward_prob_top+reward_prob_margin*2),
        y_ticks=reward_prob_yticks
    )
    shade(twin_ax)

    # --PLOT 5: actor hidden states --
    # if shared_rnn, actor_hidden_states carries the shared rnn states
    ax = axs[5]
    _plot_hidden_units(
        ax, actor_hidden_states, n_hidden_units_plot,
        'rnn' if args.shared_rnn else 'actor'
    )
    shade(ax)

    # --PLOT 6: critic hidden states --
    if not args.shared_rnn:
        ax = axs[6]
        _plot_hidden_units(
            ax, critic_hidden_states, n_hidden_units_plot, 'critic'
        )
        shade(ax)

    # fig settings
    fig.tight_layout(rect=[0, 0.02, 1, 0.98])

    return fig


def visualize_policy_CoupledBlockDF_rl2(
    model_name, 
    env, 
    args, 
    info,
    actions, 
    rewards,
    action_log_probs, 
    all_action_logits, 
    entropies, 
    state_values,
    actor_hidden_states, 
    critic_hidden_states,
    running_average_window=10,
    n_hidden_units_plot=5
):
    '''
    visualize one rollout of an rl2 policy in a CoupledBlockDF env

    The reward-prob axes are scaled to the summed reward probability of the
    first trial, ticked at [0, 0.225, 0.45], and alternate blocks are shaded
    on every panel. The per-trial arrays are those returned by 
    rollout_one_episode_rl2. At most n_hidden_units_plot hidden units are 
    plotted (fewer if the network has fewer).

    returns the matplotlib figure
    '''
    return _visualize_policy_df_rl2(
        model_name, env, args, info,
        actions, rewards,
        action_log_probs, all_action_logits, entropies, state_values,
        actor_hidden_states, critic_hidden_states,
        running_average_window, n_hidden_units_plot,
        reward_prob_top=np.sum(info['task_object'].trial_p_reward[0]),
        reward_prob_margin=0.045,
        reward_prob_yticks=[0, 0.225, 0.45],
        shade_blocks=True
    )


def visualize_policy_UncoupledBlockDF_rl2(
    model_name, 
    env, 
    args, 
    info,
    actions, 
    rewards,
    action_log_probs, 
    all_action_logits, 
    entropies, 
    state_values,
    actor_hidden_states, 
    critic_hidden_states,
    running_average_window=10,
    n_hidden_units_plot=5
):
    '''
    visualize one rollout of an rl2 policy in an UncoupledBlockDF env

    The reward-prob axes are scaled to [0, 1] and ticked at [0.1, 0.5, 0.9];
    blocks are not shaded. The per-trial arrays are those returned by 
    rollout_one_episode_rl2. At most n_hidden_units_plot hidden units are 
    plotted (fewer if the network has fewer).

    returns the matplotlib figure
    '''
    return _visualize_policy_df_rl2(
        model_name, env, args, info,
        actions, rewards,
        action_log_probs, all_action_logits, entropies, state_values,
        actor_hidden_states, critic_hidden_states,
        running_average_window, n_hidden_units_plot,
        reward_prob_top=1,
        reward_prob_margin=0.1,
        reward_prob_yticks=[0.1, 0.5, 0.9],
        shade_blocks=False
    )


def visualize_policy_RandomWalkDF_rl2(
    model_name, 
    env, 
    args, 
    info,
    actions, 
    rewards,
    action_log_probs, 
    all_action_logits, 
    entropies, 
    state_values,
    actor_hidden_states, 
    critic_hidden_states,
    running_average_window=10,
    n_hidden_units_plot=5
):
    '''
    visualize one rollout of an rl2 policy in a RandomWalkDF env

    The reward-prob axes are scaled to [0, 1] with default ticks; blocks are
    not shaded. The per-trial arrays are those returned by 
    rollout_one_episode_rl2. At most n_hidden_units_plot hidden units are 
    plotted (fewer if the network has fewer).

    returns the matplotlib figure
    '''
    return _visualize_policy_df_rl2(
        model_name, env, args, info,
        actions, rewards,
        action_log_probs, all_action_logits, entropies, state_values,
        actor_hidden_states, critic_hidden_states,
        running_average_window, n_hidden_units_plot,
        reward_prob_top=1,
        reward_prob_margin=0.1,
        reward_prob_yticks=None,
        shade_blocks=False
    )


def visualize_policy_BlockBandit_rl2(
    model_name, env, args, info,
    actions, rewards,
    action_log_probs, all_action_logits, entropies, state_values,
    actor_hidden_states, critic_hidden_states,
    running_average_window=10,
    n_hidden_units_plot=5
):
    '''
    Plot a multi-panel summary of one policy rollout in a block bandit task.

    Panels, top to bottom: action raster with running average and the block
    reward prob of action 1, chosen action prob, entropy, state value,
    p(a1) against the block reward prob of action 1, actor (or shared) rnn
    hidden states and, only when args.shared_rnn is False, critic hidden
    states. shading_blocks(ax, args, info) (defined elsewhere in this file)
    is called on every panel.
    --
    model_name: str, used in the figure title
    env: environment; env.unwrapped is handed to get_max_episode_return and
        get_optimality_score
    args: config object; reads shared_rnn, env_name and action_dim
    info: dict; reads 'block_lens' and 'reward_prob'
        (indexed as info['reward_prob'][block_idx][1])
    actions, rewards: arrays compared element-wise (==) to build the raster;
        the raster offsets are written for two actions
    action_log_probs: log-probs of the chosen actions (exponentiated here)
    all_action_logits: 2D array-like, indexed as [:, action]
    entropies, state_values: sequences plotted as-is
    actor_hidden_states, critic_hidden_states: 2D arrays indexed [:, unit]
    running_average_window: int, window of the action running average
    n_hidden_units_plot: int, number of hidden units drawn per panel
    --
    returns: the matplotlib figure
    '''
    # figure setup: the critic panel only exists when actor and critic rnns
    # are separate
    n_rows = 6 if args.shared_rnn else 7
    block_reward_prob_color = 'k'

    fig, axs = plt.subplots(
        nrows=n_rows, ncols=1,
        figsize=(12, 3*n_rows), dpi=300
    )
    fig.suptitle(f"{model_name} in {args.env_name}: {info['reward_prob'][0]}")

    # expand the per-block reward prob of action 1 to one value per trial
    rew_prob_a1 = [
        info['reward_prob'][block_idx][1]
        for block_idx, block_len in enumerate(info['block_lens'])
        for _ in range(block_len)
    ]


    # -- PLOT 0: actions trial view --
    ax = axs[0]

    episode_return = rewards.sum()
    max_episode_return = get_max_episode_return(args, env.unwrapped)
    optimality_score = get_optimality_score(args, env.unwrapped, actions)

    ax.set_title(f'Trial view: actions, '
                f'return {episode_return}, (max {max_episode_return:.1f}), '
                f'optimality score: {optimality_score:.2f}')

    # event raster: one row of events per (action, reward) pair,
    # ordered (a0, r0), (a0, r1), (a1, r0), (a1, r1)
    events = [
        np.where((actions == act) & (rewards == rew))[0]
        for act in range(args.action_dim)
        for rew in [0, 1]
    ]
    line_offsets = [-0.3, -0.3] + [1.3, 1.3]
    line_lengths = [0.2, 0.4] + [0.2, 0.4]
    ax.eventplot(events, lineoffsets=line_offsets,
                linelengths=line_lengths, linewidth=1)

    # action running average
    actions_moving_average = np.convolve(
        np.array(actions), np.ones(running_average_window), mode="same") \
            / running_average_window
    ax.plot(
        np.arange(len(actions_moving_average)),
        actions_moving_average
    )

    # block reward prob: action 1
    ax.plot(
        np.arange(len(rew_prob_a1)),
        rew_prob_a1,
        c=block_reward_prob_color
    )
    shading_blocks(ax, args, info)

    ax.set_xlabel('trial')
    ax.set_ylabel('action')
    ax.set_yticks([0, 1], [0, 1])


    # -- PLOT 1: action_log_prob --
    ax = axs[1]
    ax.set_title('chosen action prob')

    ax.plot(np.exp(action_log_probs))

    shading_blocks(ax, args, info)

    ax.set_xlabel('trial')
    ax.set_ylabel('prob')
    ax.set_ylim(0, 1.05)


    # -- PLOT 2: entropy --
    ax = axs[2]
    ax.set_title('entropy')

    ax.plot(entropies)

    shading_blocks(ax, args, info)

    ax.set_xlabel('trial')
    ax.set_ylabel('entropy')


    # -- PLOT 3: state_values --
    ax = axs[3]
    ax.set_title('state_value')

    ax.plot(state_values)

    shading_blocks(ax, args, info)

    ax.set_xlabel('trial')
    ax.set_ylabel('state_value')


    # -- PLOT 4: prob_a1 --
    ax = axs[4]
    ax.set_title('prob_a1')

    # softmax over the action logits; the row-wise max is subtracted first so
    # that large logits cannot overflow np.exp (the result is unchanged)
    logits = np.asarray(all_action_logits)
    exp_logits = np.exp(logits - logits.max(axis=1, keepdims=True))
    arr = exp_logits[:, 1] / exp_logits.sum(axis=1)
    ax.plot(np.arange(len(arr)), arr,
        label='prob_a1', color='b')

    ax.set_xlabel('trial')
    ax.set_ylabel('prob_a1')
    ax.set_ylim(0, 1)

    shading_blocks(ax, args, info)

    # block reward prob: action 1, on a secondary y-axis
    ax = ax.twinx()
    ax.plot(
        np.arange(len(rew_prob_a1)),
        rew_prob_a1,
        c=block_reward_prob_color
    )
    ax.set_ylabel('a1_reward_prob', c=block_reward_prob_color)
    ax.tick_params(axis='y', labelcolor=block_reward_prob_color)
    ax.set_ylim(0, 1)


    # --PLOT 5: actor hidden states --
    # with a shared rnn, actor_hidden_states is labelled as the rnn state
    ax = axs[5]
    if args.shared_rnn:
        ax.set_title('rnn hidden states')
    else:
        ax.set_title('actor hidden states')

    for plot_id in range(n_hidden_units_plot):
        ax.plot(actor_hidden_states[:, plot_id], lw=0.9)

    shading_blocks(ax, args, info)

    ax.set_xlabel('timestep')
    if args.shared_rnn:
        ax.set_ylabel('rnn_hidden_state')
    else:
        ax.set_ylabel('actor_hidden_state')


    # --PLOT 6: critic hidden states --
    if not args.shared_rnn:
        ax = axs[6]
        ax.set_title('critic hidden states')

        for plot_id in range(n_hidden_units_plot):
            ax.plot(critic_hidden_states[:, plot_id], lw=0.9)

        shading_blocks(ax, args, info)

        ax.set_xlabel('timestep')
        ax.set_ylabel('critic_hidden_state')


    ### fig settings
    fig.tight_layout(rect=[0, 0.02, 1, 0.98])

    return fig


def visualize_policy_RandomWalkBandit_rl2(
    model_name, env, args, info,
    actions, rewards,
    action_log_probs, all_action_logits, entropies, state_values,
    actor_hidden_states, critic_hidden_states,
    running_average_window=10,
    n_hidden_units_plot=5
):
    '''
    Plot a multi-panel summary of one policy rollout in a random-walk bandit
    task.

    Panels, top to bottom: action raster with running average and the reward
    prob of both arms, chosen action prob, entropy, state value, p(a1)
    against the reward prob of arm 1, actor (or shared) rnn hidden states
    and, only when args.shared_rnn is False, critic hidden states.
    --
    model_name: str, used in the figure title
    env: not used by this function (kept for a signature parallel to the
        other visualize_policy_* functions)
    args: config object; reads shared_rnn, env_name and action_dim
    info: dict; reads 'reward_prob', indexed as info['reward_prob'][:, arm]
    actions, rewards: arrays compared element-wise (==) to build the raster;
        the raster offsets are written for two actions
    action_log_probs: log-probs of the chosen actions (exponentiated here)
    all_action_logits: 2D array-like, indexed as [:, action]
    entropies, state_values: sequences plotted as-is
    actor_hidden_states, critic_hidden_states: 2D arrays indexed [:, unit]
    running_average_window: int, window of the action running average
    n_hidden_units_plot: int, number of hidden units drawn per panel
    --
    returns: the matplotlib figure
    '''
    # figure setup: the critic panel only exists when actor and critic rnns
    # are separate
    n_rows = 6 if args.shared_rnn else 7
    randomwalk_reward_prob_colors = ['dimgray', 'k']

    fig, axs = plt.subplots(
        nrows=n_rows, ncols=1,
        figsize=(12, 3*n_rows), dpi=300
    )
    fig.suptitle(f"{model_name} in {args.env_name}")


    # -- PLOT 0: actions trial view --
    ax = axs[0]
    episode_return = rewards.sum()
    ax.set_title(
        f'Trial view: actions, '
        f'return {episode_return}'
        )

    # event raster: one row of events per (action, reward) pair,
    # ordered (a0, r0), (a0, r1), (a1, r0), (a1, r1)
    events = [
        np.where((actions == act) & (rewards == rew))[0]
        for act in range(args.action_dim)
        for rew in [0, 1]
    ]
    line_offsets = [-0.3, -0.3] + [1.3, 1.3]
    line_lengths = [0.2, 0.4] + [0.2, 0.4]
    ax.eventplot(events, lineoffsets=line_offsets,
                linelengths=line_lengths, linewidth=1)

    # action running average
    actions_moving_average = np.convolve(
        np.array(actions), np.ones(running_average_window), mode="same") \
            / running_average_window
    ax.plot(
        np.arange(len(actions_moving_average)),
        actions_moving_average
    )

    # random walk reward prob of both arms
    for bandit_ind in range(2):
        arr = info['reward_prob'][:, bandit_ind]
        ax.plot(
            np.arange(len(arr)), arr,
            c=randomwalk_reward_prob_colors[bandit_ind],
            label=f'rw_arm_{bandit_ind}'
        )

    ax.set_xlabel('trial')
    ax.set_ylabel('action')
    ax.set_yticks([0, 1], [0, 1])
    ax.legend()


    # -- PLOT 1: action_log_prob --
    ax = axs[1]
    ax.set_title('chosen action prob')

    ax.plot(np.exp(action_log_probs))

    ax.set_xlabel('trial')
    ax.set_ylabel('prob')
    ax.set_ylim(0, 1.05)


    # -- PLOT 2: entropy --
    ax = axs[2]
    ax.set_title('entropy')

    ax.plot(entropies)

    ax.set_xlabel('trial')
    ax.set_ylabel('entropy')


    # -- PLOT 3: state_values --
    ax = axs[3]
    ax.set_title('state_value')

    ax.plot(state_values)

    ax.set_xlabel('trial')
    ax.set_ylabel('state_value')


    # -- PLOT 4: pred_reward_probs_a1 & prob_a1 --
    ax = axs[4]
    ax.set_title('pred_reward_probs & prob_a1')

    # softmax over the action logits; the row-wise max is subtracted first so
    # that large logits cannot overflow np.exp (the result is unchanged)
    logits = np.asarray(all_action_logits)
    exp_logits = np.exp(logits - logits.max(axis=1, keepdims=True))
    arr = exp_logits[:, 1] / exp_logits.sum(axis=1)
    ax.plot(np.arange(len(arr)), arr,
        label='prob_a1', color='b')

    ax.set_xlabel('trial')
    ax.set_ylabel('pred_reward_probs/ prob_a1')
    ax.set_ylim(0, 1)

    ax.legend()

    # random walk reward prob of arm 1 on ONE secondary y-axis; the arm index
    # is stated explicitly instead of reusing the loop variable of PLOT 0
    arm = 1
    arm_color = randomwalk_reward_prob_colors[arm]
    ax = ax.twinx()
    arr = info['reward_prob'][:, arm]
    ax.plot(
        np.arange(len(arr)), arr,
        c=arm_color,
        label=f'rw_arm_{arm}'
    )
    ax.set_ylabel('a1_reward_prob', c=arm_color)
    ax.tick_params(axis='y', labelcolor=arm_color)
    ax.set_ylim(0, 1)


    # --PLOT 5: actor hidden states --
    # with a shared rnn, actor_hidden_states is labelled as the rnn state
    ax = axs[5]
    if args.shared_rnn:
        ax.set_title('rnn hidden states')
    else:
        ax.set_title('actor hidden states')

    for plot_id in range(n_hidden_units_plot):
        ax.plot(actor_hidden_states[:, plot_id], lw=0.9)

    ax.set_xlabel('timestep')
    if args.shared_rnn:
        ax.set_ylabel('rnn_hidden_state')
    else:
        ax.set_ylabel('actor_hidden_state')


    # --PLOT 6: critic hidden states --
    if not args.shared_rnn:
        ax = axs[6]
        ax.set_title('critic hidden states')

        for plot_id in range(n_hidden_units_plot):
            ax.plot(critic_hidden_states[:, plot_id], lw=0.9)

        ax.set_xlabel('timestep')
        ax.set_ylabel('critic_hidden_state')


    ### fig settings
    fig.tight_layout(rect=[0, 0.02, 1, 0.98])

    return fig


def visualize_policy_TimedBandit(
    model_name, env, args, info,
    go_cues, actions, rewards,
    action_log_probs, entropies, state_values,
    actor_hidden_states, critic_hidden_states,
    running_average_window=10,
    n_hidden_units_plot=5
):
    '''
    Plot a multi-panel summary of one policy rollout in a timed bandit task.

    Panels, top to bottom: trial view of the actions taken at go cues,
    timestep view of all actions, chosen action prob, entropy, state value,
    actor (or shared) rnn hidden states and, only when args.shared_rnn is
    False, critic hidden states. Go cues are drawn as dashed vertical lines
    on every timestep panel.
    --
    model_name: str, used in the figure title
    env: environment; env.unwrapped is handed to get_max_episode_return and
        get_optimality_score
    args: config object; reads shared_rnn, env_name and action_dim
    info: dict; reads 'block_lens' and 'reward_prob'
    go_cues: array, non-zero / ==1 entries mark go-cue timesteps
    actions, rewards: arrays indexed by timestep
        (action codes per the raster: 0, 1, and 2 for withhold)
    action_log_probs: log-probs of the chosen actions (exponentiated here)
    entropies, state_values: arrays (their .min()/.max() set the cue lines)
    actor_hidden_states, critic_hidden_states: 2D arrays indexed [:, unit]
    running_average_window: int, window of the running averages
    n_hidden_units_plot: int, number of hidden units drawn per panel
    --
    returns: the matplotlib figure
    '''
    separate_rnns = not args.shared_rnn
    n_rows = 7 if separate_rnns else 6
    fig, axs = plt.subplots(
        nrows=n_rows, ncols=1,
        figsize=(12, 3*n_rows), dpi=300
    )
    fig.suptitle(f"{model_name} in {args.env_name}: {info['reward_prob'][0]}")

    # timesteps carrying a go cue, shared by all timestep-view panels
    go_cue_timesteps = np.where(go_cues==1)[0]

    def running_mean(series):
        # boxcar sum over the window (mode="same"), scaled by the window size
        window = np.ones(running_average_window)
        return np.convolve(np.array(series), window, mode="same") \
            / running_average_window

    def raster_events(act_series, rew_series, reward_values):
        # one index array per (action, reward) pair, actions in the outer loop
        return [
            np.where((act_series==act) & (rew_series==rew))[0]
            for act in range(args.action_dim)
            for rew in reward_values
        ]

    def mark_go_cues(panel_ax, y_low, y_high):
        panel_ax.vlines(go_cue_timesteps, y_low, y_high,
                        'k', linestyle='--', lw=0.2)

    def draw_hidden_states(panel_ax, hidden_states, title, y_label):
        panel_ax.set_title(title)
        for unit_id in range(n_hidden_units_plot):
            panel_ax.plot(hidden_states[:, unit_id], lw=0.6)
        mark_go_cues(panel_ax, -1.5, 1.5)
        panel_ax.set_xlabel('timestep')
        panel_ax.set_ylabel(y_label)


    ### PLOT 0: actions trial view, i.e. at go_cue
    trial_ax = axs[0]
    go_cue_rows = np.where(go_cues.astype(int))[0]
    actions_at_go_cue = actions[go_cue_rows]
    rewards_at_go_cue = rewards[go_cue_rows]

    episode_return = rewards.sum()
    max_episode_return = get_max_episode_return(args, env.unwrapped)
    optimality_score = get_optimality_score(args, env.unwrapped, actions_at_go_cue)

    title_parts = (
        'Trial view: actions at go_cue, ',
        f'return {episode_return}, (max {max_episode_return:.1f}), ',
        f'optimality score: {optimality_score:.2f}',
    )
    trial_ax.set_title(''.join(title_parts))

    # event raster; 0: action 0, 1: action 1, 2: withhold
    trial_ax.eventplot(
        raster_events(actions_at_go_cue, rewards_at_go_cue, [0, 1]),
        lineoffsets=[-0.2, -0.2, 1.2, 1.2, -1, -1],
        linelengths=[0.2, 0.4, 0.2, 0.4, 0.2, 0.4],
        linewidth=1
    )

    # action running average; the contribution of withhold (code 2) is
    # subtracted again so that only actions 0/1 enter the average
    withhold_only = np.copy(actions_at_go_cue)
    withhold_only[np.where(actions_at_go_cue!=2)[0]] = 0
    choice_average = running_mean(actions_at_go_cue) - running_mean(withhold_only)
    trial_ax.plot(np.arange(len(choice_average)), choice_average)

    # block reward prob: action 1, expanded to one value per trial
    rew_prob_a1 = [
        info['reward_prob'][block_idx][1]
        for block_idx, block_len in enumerate(info['block_lens'])
        for _ in range(block_len)
    ]
    trial_ax.plot(np.arange(len(rew_prob_a1)), rew_prob_a1, c='m')

    # actual reward prob: action 1
    chose_action_1 = np.copy(actions_at_go_cue)
    chose_action_1[actions_at_go_cue!=1] = 0
    reward_average_a1 = running_mean(rewards_at_go_cue * chose_action_1)
    trial_ax.plot(np.arange(len(reward_average_a1)), reward_average_a1, c='r')

    trial_ax.set_xlabel('trial (at go_cue)')
    trial_ax.set_ylabel('action')
    trial_ax.set_yticks(np.arange(-1, -1+args.action_dim), ['withhold', 0, 1])


    ### PLOT 1: actions, timestep view
    step_ax = axs[1]
    step_ax.set_title('Timestep view: actions')

    # event raster; 0: action 0, 1: action 1, 2: withhold
    step_ax.eventplot(
        raster_events(actions, rewards, [0, 1]),
        lineoffsets=[0, 0, 1, 1, -1, -1],
        linelengths=[0.3, 0.6, 0.3, 0.6, 0.3, 0.6],
        linewidth=0.5
    )
    # event raster: for incorrectly timed actions (reward == -1)
    step_ax.eventplot(
        raster_events(actions, rewards, [-1]),
        lineoffsets=[0, 1, -1],
        linelengths=[0.5, 0.5, 0.5],
        linewidth=0.7, color='r'
    )

    mark_go_cues(step_ax, -1.5, 1.5)

    step_ax.set_xlabel('timestep')
    step_ax.set_ylabel('action')
    step_ax.set_yticks(np.arange(-1, -1+args.action_dim), ['withhold', 0, 1])


    ### PLOT 2-4: chosen action prob, entropy, state_values
    series_panels = (
        (axs[2], 'chosen action prob', np.exp(action_log_probs), 'prob'),
        (axs[3], 'entropy', entropies, 'entropy'),
        (axs[4], 'state_value', state_values, 'state_value'),
    )
    for panel_ax, title, series, y_label in series_panels:
        panel_ax.set_title(title)
        panel_ax.plot(series, lw=0.6)
        # go-cue lines span the range of the plotted series
        mark_go_cues(panel_ax, series.min(), series.max())
        panel_ax.set_xlabel('timestep')
        panel_ax.set_ylabel(y_label)
    axs[2].set_ylim(0, 1.05)


    ### PLOT 5-6: actor/ shared hidden states, critic hidden states
    if separate_rnns:
        draw_hidden_states(
            axs[5], actor_hidden_states,
            'actor hidden states', 'actor_hidden_state'
        )
        draw_hidden_states(
            axs[6], critic_hidden_states,
            'critic hidden states', 'critic_hidden_state'
        )
    else:
        draw_hidden_states(
            axs[5], actor_hidden_states,
            'shared rnn hidden states', 'shared_rnn_hidden_state'
        )


    ### fig settings
    fig.tight_layout(rect=[0, 0.02, 1, 0.98])

    return fig


def plot_rnn_hidden_states(
    rnn_hidden_states, rnn_type,
    n_hidden_units_plot=64,
    timestep2plot_i=0, timestep2plot_f=100,
):
    '''
    Plot the time course of individual rnn hidden units, one subplot per
    unit, arranged on a roughly square grid.
    --
    rnn_hidden_states: 2D array-like indexed as [timestep, hidden_unit]
    rnn_type: str, one of [actor, critic, shared_rnn]; only used in the title
    n_hidden_units_plot: int, number of hidden units to draw; clipped to the
        number of units available
    timestep2plot_i, timestep2plot_f: int, half-open window [i, f) of
        timesteps to draw; clipped to the recorded timesteps
    --
    returns: the matplotlib figure
    '''
    states = np.asarray(rnn_hidden_states)
    n_timesteps, n_units = states.shape[0], states.shape[1]

    # never ask for more units than exist (would raise IndexError)
    n_plot = max(0, min(n_hidden_units_plot, n_units))

    # clip the window so that x and y always have the same length
    t_start = min(max(timestep2plot_i, 0), n_timesteps)
    t_end = min(max(timestep2plot_f, t_start), n_timesteps)
    timesteps = np.arange(t_start, t_end)

    # ceil division gives exactly as many rows as needed (no spare empty row
    # when n_plot is a perfect square)
    n_cols = max(1, int(np.sqrt(n_plot)))
    n_rows = max(1, (n_plot + n_cols - 1) // n_cols)

    # squeeze=False keeps axs 2D even for a single row or column
    fig, axs = plt.subplots(
        nrows=n_rows, ncols=n_cols,
        figsize=(20, 20), dpi=300,
        squeeze=False
    )
    fig.suptitle(f'{rnn_type} hidden states')

    for hidden_id in range(n_plot):
        row_id, col_id = divmod(hidden_id, n_cols)
        ax = axs[row_id][col_id]
        ax.set_title(f'hidden {hidden_id}')
        ax.plot(timesteps, states[t_start:t_end, hidden_id])

    # hide the grid cells that have no unit assigned
    for ax in axs.flat[n_plot:]:
        ax.set_visible(False)

    fig.tight_layout(rect=[0, 0.02, 1, 0.98])

    return fig


def get_rnn_connectivity(policy_net, rnn_type):
    '''
    Collect the weight matrices of an rnn and of its output layer(s) as
    numpy arrays.
    --
    policy_net: network exposing the sub-modules read here:
        `shared_rnn`, `actor_output`, `critic_output` for 'shared_rnn';
        `actor_rnn`, `actor_output` for 'actor';
        `critic_rnn`, `critic_output` for 'critic'
    rnn_type: str, one of [actor, critic, shared_rnn]
    --
    returns: dict of parameter name -> numpy array. Only parameters whose
        name contains 'weight' are kept. rnn weights keep their own names;
        output weights are stored under 'weight_ho_l0' (actor / critic) or
        'actor_weight_ho_l0' and 'critic_weight_ho_l0' (shared_rnn)
    raises: ValueError if rnn_type is not one of the supported values
    '''
    def _weight_arrays(module):
        # (name, array) pairs of every parameter whose name contains 'weight'
        return [
            (name, param.detach().cpu().numpy())
            for name, param in module.named_parameters()
            if 'weight' in name
        ]

    if rnn_type == 'shared_rnn':
        rnn = policy_net.shared_rnn
        output_modules = {
            'actor_weight_ho_l0': policy_net.actor_output,
            'critic_weight_ho_l0': policy_net.critic_output,
        }
    elif rnn_type == 'actor':
        rnn = policy_net.actor_rnn
        output_modules = {'weight_ho_l0': policy_net.actor_output}
    elif rnn_type == 'critic':
        rnn = policy_net.critic_rnn
        output_modules = {'weight_ho_l0': policy_net.critic_output}
    else:
        raise ValueError(
            f"unknown rnn_type {rnn_type!r}: "
            "expected one of 'actor', 'critic', 'shared_rnn'"
        )

    rnn_weights = dict(_weight_arrays(rnn))
    for key, module in output_modules.items():
        # if an output module holds several weight tensors, the last one wins
        for _, weight in _weight_arrays(module):
            rnn_weights[key] = weight

    return rnn_weights


def plot_rnn_connectivity(
    rnn_weights, rnn_type,
    args,
    cmap='bwr'
):
    '''
    Show the rnn weight matrices side by side as images, each with its own
    colorbar: input->hidden, hidden->hidden, then hidden->output
    (one panel for actor / critic, two panels for shared_rnn).
    --
    rnn_weights: dict of weight arrays; reads 'weight_ih_l0', 'weight_hh_l0'
        and 'weight_ho_l0' (actor / critic) or 'actor_weight_ho_l0' and
        'critic_weight_ho_l0' (shared_rnn)
    rnn_type: str, one of [actor, critic, shared_rnn]
    args: config object; reads input_state_dim_for_policy, action_dim,
        reward_dim and rnn_hidden_dim to set the relative panel widths
    cmap: colormap passed to every imshow call
    --
    returns: the matplotlib figure
    raises: ValueError if rnn_type is not one of the supported values
    '''
    # (title, key in rnn_weights) of the hidden->output panels and the
    # relative width used for each of them
    output_panels = {
        'shared_rnn': (
            [('actor_ho', 'actor_weight_ho_l0'),
             ('critic_ho', 'critic_weight_ho_l0')],
            [args.action_dim, args.reward_dim],
        ),
        'actor': ([('ho', 'weight_ho_l0')], [args.action_dim]),
        'critic': ([('ho', 'weight_ho_l0')], [args.reward_dim]),
    }
    if rnn_type not in output_panels:
        raise ValueError(
            f"unknown rnn_type {rnn_type!r}: "
            "expected one of 'actor', 'critic', 'shared_rnn'"
        )
    ho_panels, ho_widths = output_panels[rnn_type]

    input_dim = args.input_state_dim_for_policy + args.action_dim + args.reward_dim
    width_ratios = [input_dim, args.rnn_hidden_dim] + ho_widths

    fig, axs = plt.subplots(
        nrows=1, ncols=len(width_ratios),
        gridspec_kw={'width_ratios': width_ratios},
        figsize=(14, 6), dpi=300
    )
    fig.suptitle(f'{rnn_type} connectivity')

    panels = [
        ('weight_ih', rnn_weights['weight_ih_l0']),
        ('weight_hh', rnn_weights['weight_hh_l0']),
    ]
    # hidden->output weights are shown transposed
    panels += [(title, rnn_weights[key].T) for title, key in ho_panels]

    for ax, (title, weights) in zip(axs, panels):
        ax.set_title(title)
        # the colormap is passed per image rather than via plt.set_cmap, so
        # the process-wide default colormap is left untouched
        image = ax.imshow(weights, cmap=cmap)
        fig.colorbar(image, ax=ax)

    fig.tight_layout(rect=[0, 0.02, 1, 0.98])

    return fig